// Copyright © 2023-2024 Apple Inc.

#include <algorithm>
#include <cassert>
#include <numeric>

#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/lapack.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

namespace mlx::core {

namespace {

///////////////////////////////////////////////////////////////////////////////
// Naive reference conv
///////////////////////////////////////////////////////////////////////////////

template <typename T>
void slow_conv_1D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip,
    Stream stream) {
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(in);
  encoder.set_input_array(wt);
  encoder.set_output_array(out);

  encoder.dispatch([start_wt_ptr = wt.data<T>(),
                    in_ptr = in.data<T>(),
                    out_ptr = out.data<T>(),

                    N = in.shape(
                        0), // Batch size, should be the same as out.shape(0)
                    iH = 1 +
                        in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
                    oH = out.shape(1), // Output spatial dim
                    wH = wt.shape(1), // Weight spatial dim
                    groups = in.shape(2) / wt.shape(2),
                    O = wt.shape(0), // Out channels
                    C_per_group = wt.shape(2),

                    in_stride_N = in.strides()[0],
                    in_stride_H = in.strides()[1],
                    in_stride_C = in.strides()[2],

                    wt_stride_O = wt.strides()[0],
                    wt_stride_H = wt.strides()[1],
                    wt_stride_C = wt.strides()[2],

                    out_stride_N = out.strides()[0],
                    out_stride_H = out.strides()[1],
                    out_stride_O = out.strides()[2],

                    flip,
                    padding_lo = padding_lo[0],
                    padding_hi = padding_hi[0],
                    wt_stride = wt_strides[0],
                    wt_dilation = wt_dilation[0],
                    in_dilation = in_dilation[0]]() mutable {
    auto O_per_group = O / groups;

    for (int n = 0; n < N; ++n) {
      for (int oh = 0; oh < oH; ++oh) {
        for (int g = 0; g < groups; ++g) {
          for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
            const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O;
            float r = 0.;

            for (int wh = 0; wh < wH; ++wh) {
              const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H;

              int wh_flip = flip ? (wH - wh - 1) : wh;
              int ih = oh * wt_stride - padding_lo + wh_flip * wt_dilation;

              auto ih_div = std::div(ih, in_dilation);

              if (ih >= 0 && ih < iH && ih_div.rem == 0) {
                for (int c = g * C_per_group; c < (g + 1) * C_per_group; ++c) {
                  r +=
                      static_cast<float>(
                          in_ptr[ih_div.quot * in_stride_H + c * in_stride_C]) *
                      static_cast<float>(
                          wt_ptr[(c % C_per_group) * wt_stride_C]);
                } // c

              } // ih check
            } // wh

            out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast<T>(r);
          } // o
        } // g
      } // oh

      in_ptr += in_stride_N;
      out_ptr += out_stride_N;
    } // n
  });
}

template <typename T>
void slow_conv_2D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip,
    Stream stream) {
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(in);
  encoder.set_input_array(wt);
  encoder.set_output_array(out);

  encoder.dispatch(
      [st_wt_ptr = wt.data<T>(),
       st_in_ptr = in.data<T>(),
       st_out_ptr = out.data<T>(),

       N = in.shape(0), // Batch size, should be the same as out.shape(0)
       iH = 1 + in_dilation[0] * (in.shape(1) - 1), // Input spatial dim
       iW = 1 + in_dilation[1] * (in.shape(2) - 1), // Input spatial dim
       C = in.shape(3), // In channels
       oH = out.shape(1), // Output spatial dim
       oW = out.shape(2), // Output spatial dim
       O = wt.shape(0), // Out channels
       wH = wt.shape(1), // Weight spatial dim
       wW = wt.shape(2), // Weight spatial dim

       groups = in.shape(3) / wt.shape(3),
       C_per_group = wt.shape(3),

       in_stride_N = in.strides()[0],
       in_stride_H = in.strides()[1],
       in_stride_W = in.strides()[2],
       in_stride_C = in.strides()[3],

       wt_stride_O = wt.strides()[0],
       wt_stride_H = wt.strides()[1],
       wt_stride_W = wt.strides()[2],
       wt_stride_C = wt.strides()[3],

       out_stride_N = out.strides()[0],
       out_stride_H = out.strides()[1],
       out_stride_W = out.strides()[2],
       out_stride_O = out.strides()[3],

       padding_lo,
       padding_hi,
       wt_strides,
       wt_dilation,
       in_dilation,
       flip]() mutable {
        bool is_idil_one = in_dilation[0] == 1 && in_dilation[1] == 1;

        const int O_per_group = O / groups;
        auto pt_conv_no_checks =
            [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
              out_ptr += oh * out_stride_H + ow * out_stride_W;
              int ih_base = oh * wt_strides[0] - padding_lo[0];
              int iw_base = ow * wt_strides[1] - padding_lo[1];

              for (int g = 0; g < groups; ++g) {
                for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
                  float r = 0.;

                  for (int wh = 0; wh < wH; ++wh) {
                    for (int ww = 0; ww < wW; ++ww) {
                      int wh_flip = flip ? wH - wh - 1 : wh;
                      int ww_flip = flip ? wW - ww - 1 : ww;
                      int ih = ih_base + wh_flip * wt_dilation[0];
                      int iw = iw_base + ww_flip * wt_dilation[1];

                      const T* wt_ptr_pt =
                          wt_ptr + wh * wt_stride_H + ww * wt_stride_W;
                      const T* in_ptr_pt =
                          in_ptr + ih * in_stride_H + iw * in_stride_W;

                      for (int c = g * C_per_group; c < (g + 1) * C_per_group;
                           ++c) {
                        r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
                            static_cast<float>(
                                 wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
                      } // c
                    } // ww
                  } // wh

                  out_ptr[0] = static_cast<T>(r);
                  out_ptr += out_stride_O;
                  wt_ptr += wt_stride_O;
                } // o
              } // g
            };

        int jump_h = flip ? -wt_dilation[0] : wt_dilation[0];
        int jump_w = flip ? -wt_dilation[1] : wt_dilation[1];

        int init_h = (flip ? (wH - 1) * wt_dilation[0] : 0);
        int init_w = (flip ? (wW - 1) * wt_dilation[1] : 0);

        int f_wgt_jump_h =
            std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
        int f_wgt_jump_w =
            std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];

        int f_out_jump_h =
            std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
        int f_out_jump_w =
            std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];

        std::vector<int> base_h(f_out_jump_h);
        std::vector<int> base_w(f_out_jump_w);

        for (int i = 0; i < f_out_jump_h; ++i) {
          int ih_loop = i * wt_strides[0] - padding_lo[0] + init_h;

          int wh_base = 0;
          while (wh_base < wH && ih_loop % in_dilation[0] != 0) {
            wh_base++;
            ih_loop += jump_h;
          }

          base_h[i] = wh_base;
        }

        for (int j = 0; j < f_out_jump_w; ++j) {
          int iw_loop = j * wt_strides[1] - padding_lo[1] + init_w;

          int ww_base = 0;
          while (ww_base < wW && iw_loop % in_dilation[1] != 0) {
            ww_base++;
            iw_loop += jump_w;
          }

          base_w[j] = ww_base;
        }

        auto pt_conv_all_checks =
            [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) {
              out_ptr += oh * out_stride_H + ow * out_stride_W;

              int ih_base = oh * wt_strides[0] - padding_lo[0];
              int iw_base = ow * wt_strides[1] - padding_lo[1];

              int wh_base = base_h[oh % f_out_jump_h];
              int ww_base = base_w[ow % f_out_jump_w];

              for (int g = 0; g < groups; ++g) {
                for (int o = g * O_per_group; o < (g + 1) * O_per_group; ++o) {
                  float r = 0.;

                  for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
                    for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
                      int wh_flip = flip ? wH - wh - 1 : wh;
                      int ww_flip = flip ? wW - ww - 1 : ww;
                      int ih = ih_base + wh_flip * wt_dilation[0];
                      int iw = iw_base + ww_flip * wt_dilation[1];

                      if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) {
                        const T* wt_ptr_pt =
                            wt_ptr + wh * wt_stride_H + ww * wt_stride_W;

                        int ih_dil = !is_idil_one ? (ih / in_dilation[0]) : ih;
                        int iw_dil = !is_idil_one ? (iw / in_dilation[1]) : iw;

                        const T* in_ptr_pt = in_ptr + ih_dil * in_stride_H +
                            iw_dil * in_stride_W;

                        for (int c = g * C_per_group; c < (g + 1) * C_per_group;
                             ++c) {
                          r += static_cast<float>(in_ptr_pt[c * in_stride_C]) *
                              static_cast<float>(
                                   wt_ptr_pt[(c % C_per_group) * wt_stride_C]);
                        } // c

                      } // ih, iw check
                    } // ww
                  } // wh

                  out_ptr[0] = static_cast<T>(r);
                  out_ptr += out_stride_O;
                  wt_ptr += wt_stride_O;
                } // o
              } // g
            };

        int oH_border_0 = 0;
        int oH_border_1 = is_idil_one
            ? ((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0])
            : oH;
        int oH_border_2 = std::max(
            oH_border_1,
            (iH + padding_lo[0] - wH * wt_dilation[0]) / wt_strides[0]);
        int oH_border_3 = oH;

        int oW_border_0 = 0;
        int oW_border_1 = is_idil_one
            ? ((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1])
            : oW;
        int oW_border_2 = std::max(
            oW_border_1,
            (iW + padding_lo[1] - wW * wt_dilation[1]) / wt_strides[1]);
        int oW_border_3 = oW;

        for (int n = 0; n < N; ++n) {
          // Case 1: oh might put us out of bounds
          for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
            for (int ow = 0; ow < oW; ++ow) {
              pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
            } // ow
          } // oh

          // Case 2: oh in bounds
          for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
            // Case a: ow might put us out of bounds
            for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
              pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
            } // ow

            // Case b: ow in bounds
            for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
              pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
            } // ow

            // Case c: ow might put us out of bounds
            for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
              pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
            } // ow

          } // oh

          // Case 3: oh might put us out of bounds
          for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
            for (int ow = 0; ow < oW; ++ow) {
              pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow);
            } // ow
          } // oh

          st_in_ptr += in_stride_N;
          st_out_ptr += out_stride_N;

        } // n
      });
}

// Reference (direct-loop) 3D convolution, dispatched onto the CPU encoder.
//
// Layouts, as implied by the indexing below: `in` is (N, D, H, W, C), `wt` is
// (O, wD, wH, wW, C) and `out` is (N, oD, oH, oW, O). There is no group
// handling here: every output channel reads all C input channels. Each
// output element is accumulated in float and cast back to T.
//
// `in_dilation` is handled virtually, by visiting only taps that land on real
// input samples. `flip` reverses the kernel spatially. `padding_hi` is not
// read, because the output extent comes from out.shape().
//
// Each output axis is split into [0, b1) | [b1, b2) | [b2, o). Points inside
// the middle range on all three axes use the unchecked kernel. Everything else
// checks every tap.
template <typename T>
void slow_conv_3D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip,
    Stream stream) {
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(in);
  encoder.set_input_array(wt);
  encoder.set_output_array(out);

  encoder.dispatch([st_wt_ptr = wt.data<T>(),
                    st_in_ptr = in.data<T>(),
                    st_out_ptr = out.data<T>(),

                    N = in.shape(
                        0), // Batch size, should be the same as out.shape(0)
                    iD = 1 +
                        in_dilation[0] * (in.shape(1) - 1), // Dilated input dim
                    iH = 1 +
                        in_dilation[1] * (in.shape(2) - 1), // Dilated input dim
                    iW = 1 +
                        in_dilation[2] * (in.shape(3) - 1), // Dilated input dim
                    oD = out.shape(1), // Output spatial dim
                    oH = out.shape(2), // Output spatial dim
                    oW = out.shape(3), // Output spatial dim
                    O = wt.shape(0), // Out channels
                    C = wt.shape(4), // In channels
                    wD = wt.shape(1), // Weight spatial dim
                    wH = wt.shape(2), // Weight spatial dim
                    wW = wt.shape(3), // Weight spatial dim

                    in_stride_N = in.strides()[0],
                    in_stride_D = in.strides()[1],
                    in_stride_H = in.strides()[2],
                    in_stride_W = in.strides()[3],
                    in_stride_C = in.strides()[4],

                    wt_stride_O = wt.strides()[0],
                    wt_stride_D = wt.strides()[1],
                    wt_stride_H = wt.strides()[2],
                    wt_stride_W = wt.strides()[3],
                    wt_stride_C = wt.strides()[4],

                    out_stride_N = out.strides()[0],
                    out_stride_D = out.strides()[1],
                    out_stride_H = out.strides()[2],
                    out_stride_W = out.strides()[3],
                    out_stride_O = out.strides()[4],
                    padding_lo,
                    wt_strides,
                    wt_dilation,
                    in_dilation,
                    flip]() mutable {
    bool is_idil_one =
        in_dilation[0] == 1 && in_dilation[1] == 1 && in_dilation[2] == 1;

    // Computes all output channels at (od, oh, ow) without bounds checks.
    // Only valid when every tap of the window lands inside the input.
    auto pt_conv_no_checks = [&](const T* in_ptr,
                                 const T* wt_ptr,
                                 T* out_ptr,
                                 int od,
                                 int oh,
                                 int ow) {
      out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;
      int id_base = od * wt_strides[0] - padding_lo[0];
      int ih_base = oh * wt_strides[1] - padding_lo[1];
      int iw_base = ow * wt_strides[2] - padding_lo[2];

      for (int o = 0; o < O; ++o) {
        float r = 0.;

        for (int wd = 0; wd < wD; ++wd) {
          for (int wh = 0; wh < wH; ++wh) {
            for (int ww = 0; ww < wW; ++ww) {
              int wd_flip = flip ? wD - wd - 1 : wd;
              int wh_flip = flip ? wH - wh - 1 : wh;
              int ww_flip = flip ? wW - ww - 1 : ww;
              int id = id_base + wd_flip * wt_dilation[0];
              int ih = ih_base + wh_flip * wt_dilation[1];
              int iw = iw_base + ww_flip * wt_dilation[2];

              const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +
                  wh * wt_stride_H + ww * wt_stride_W;
              const T* in_ptr_pt = in_ptr + id * in_stride_D +
                  ih * in_stride_H + iw * in_stride_W;

              for (int c = 0; c < C; ++c) {
                r += static_cast<float>(in_ptr_pt[0]) *
                    static_cast<float>(wt_ptr_pt[0]);
                in_ptr_pt += in_stride_C;
                wt_ptr_pt += wt_stride_C;
              } // c

            } // ww
          } // wh
        } // wd

        out_ptr[0] = static_cast<T>(r);
        out_ptr += out_stride_O;
        wt_ptr += wt_stride_O;
      } // o
    };

    int jump_d = flip ? -wt_dilation[0] : wt_dilation[0];
    int jump_h = flip ? -wt_dilation[1] : wt_dilation[1];
    int jump_w = flip ? -wt_dilation[2] : wt_dilation[2];

    int init_d = (flip ? (wD - 1) * wt_dilation[0] : 0);
    int init_h = (flip ? (wH - 1) * wt_dilation[1] : 0);
    int init_w = (flip ? (wW - 1) * wt_dilation[2] : 0);

    // Kernel-index step that keeps the input coordinate on real samples.
    int f_wgt_jump_d =
        std::lcm(in_dilation[0], wt_dilation[0]) / wt_dilation[0];
    int f_wgt_jump_h =
        std::lcm(in_dilation[1], wt_dilation[1]) / wt_dilation[1];
    int f_wgt_jump_w =
        std::lcm(in_dilation[2], wt_dilation[2]) / wt_dilation[2];

    // Period (in output index) with which the first valid tap repeats.
    int f_out_jump_d = std::lcm(in_dilation[0], wt_strides[0]) / wt_strides[0];
    int f_out_jump_h = std::lcm(in_dilation[1], wt_strides[1]) / wt_strides[1];
    int f_out_jump_w = std::lcm(in_dilation[2], wt_strides[2]) / wt_strides[2];

    std::vector<int> base_d(f_out_jump_d);
    std::vector<int> base_h(f_out_jump_h);
    std::vector<int> base_w(f_out_jump_w);

    for (int i = 0; i < f_out_jump_d; ++i) {
      int id_loop = i * wt_strides[0] - padding_lo[0] + init_d;

      int wd_base = 0;
      while (wd_base < wD && id_loop % in_dilation[0] != 0) {
        wd_base++;
        id_loop += jump_d;
      }

      base_d[i] = wd_base;
    }

    for (int i = 0; i < f_out_jump_h; ++i) {
      int ih_loop = i * wt_strides[1] - padding_lo[1] + init_h;

      int wh_base = 0;
      while (wh_base < wH && ih_loop % in_dilation[1] != 0) {
        wh_base++;
        ih_loop += jump_h;
      }

      base_h[i] = wh_base;
    }

    for (int j = 0; j < f_out_jump_w; ++j) {
      int iw_loop = j * wt_strides[2] - padding_lo[2] + init_w;

      int ww_base = 0;
      while (ww_base < wW && iw_loop % in_dilation[2] != 0) {
        ww_base++;
        iw_loop += jump_w;
      }

      base_w[j] = ww_base;
    }

    // Computes all output channels at (od, oh, ow), skipping taps outside the
    // dilated input or inside an in_dilation gap.
    auto pt_conv_all_checks = [&](const T* in_ptr,
                                  const T* wt_ptr,
                                  T* out_ptr,
                                  int od,
                                  int oh,
                                  int ow) {
      out_ptr += od * out_stride_D + oh * out_stride_H + ow * out_stride_W;

      int id_base = od * wt_strides[0] - padding_lo[0];
      int ih_base = oh * wt_strides[1] - padding_lo[1];
      int iw_base = ow * wt_strides[2] - padding_lo[2];

      int wd_base = base_d[od % f_out_jump_d];
      int wh_base = base_h[oh % f_out_jump_h];
      int ww_base = base_w[ow % f_out_jump_w];

      for (int o = 0; o < O; ++o) {
        float r = 0.;

        for (int wd = wd_base; wd < wD; wd += f_wgt_jump_d) {
          for (int wh = wh_base; wh < wH; wh += f_wgt_jump_h) {
            for (int ww = ww_base; ww < wW; ww += f_wgt_jump_w) {
              int wd_flip = flip ? wD - wd - 1 : wd;
              int wh_flip = flip ? wH - wh - 1 : wh;
              int ww_flip = flip ? wW - ww - 1 : ww;
              int id = id_base + wd_flip * wt_dilation[0];
              int ih = ih_base + wh_flip * wt_dilation[1];
              int iw = iw_base + ww_flip * wt_dilation[2];

              if (id >= 0 && id < iD && ih >= 0 && ih < iH && iw >= 0 &&
                  iw < iW) {
                const T* wt_ptr_pt = wt_ptr + wd * wt_stride_D +
                    wh * wt_stride_H + ww * wt_stride_W;

                int id_dil = !is_idil_one ? (id / in_dilation[0]) : id;
                int ih_dil = !is_idil_one ? (ih / in_dilation[1]) : ih;
                int iw_dil = !is_idil_one ? (iw / in_dilation[2]) : iw;

                const T* in_ptr_pt = in_ptr + id_dil * in_stride_D +
                    ih_dil * in_stride_H + iw_dil * in_stride_W;

                for (int c = 0; c < C; ++c) {
                  r += static_cast<float>(in_ptr_pt[0]) *
                      static_cast<float>(wt_ptr_pt[0]);
                  in_ptr_pt += in_stride_C;
                  wt_ptr_pt += wt_stride_C;
                } // c

              } // iD, ih, iw check
            } // ww
          } // wh
        } // wd

        out_ptr[0] = static_cast<T>(r);
        out_ptr += out_stride_O;
        wt_ptr += wt_stride_O;
      } // o
    };

    // Borders 1 and 2 are clamped into [0, o]. Without the clamp, a large
    // padding_lo relative to the output length (or a negative padding) makes
    // the region loops below index past the ends of `out`.
    int oD_border_0 = 0;
    int oD_border_1 = is_idil_one
        ? std::clamp((padding_lo[0] + wt_strides[0] - 1) / wt_strides[0], 0, oD)
        : oD;
    int oD_border_2 = std::max(
        oD_border_1,
        std::clamp(
            (iD + padding_lo[0] - wD * wt_dilation[0]) / wt_strides[0], 0, oD));
    int oD_border_3 = oD;

    int oH_border_0 = 0;
    int oH_border_1 = is_idil_one
        ? std::clamp((padding_lo[1] + wt_strides[1] - 1) / wt_strides[1], 0, oH)
        : oH;
    int oH_border_2 = std::max(
        oH_border_1,
        std::clamp(
            (iH + padding_lo[1] - wH * wt_dilation[1]) / wt_strides[1], 0, oH));
    int oH_border_3 = oH;

    int oW_border_0 = 0;
    int oW_border_1 = is_idil_one
        ? std::clamp((padding_lo[2] + wt_strides[2] - 1) / wt_strides[2], 0, oW)
        : oW;
    int oW_border_2 = std::max(
        oW_border_1,
        std::clamp(
            (iW + padding_lo[2] - wW * wt_dilation[2]) / wt_strides[2], 0, oW));
    int oW_border_3 = oW;

    for (int n = 0; n < N; ++n) {
      // Case 1: od might put us out of bounds
      for (int od = oD_border_0; od < oD_border_1; ++od) {
        for (int oh = 0; oh < oH; ++oh) {
          for (int ow = 0; ow < oW; ++ow) {
            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
          } // ow
        } // oh
      } // od

      // Case 2: od in bounds
      for (int od = oD_border_1; od < oD_border_2; ++od) {
        // Case 2.1: oh might put us out of bounds
        for (int oh = oH_border_0; oh < oH_border_1; ++oh) {
          for (int ow = 0; ow < oW; ++ow) {
            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
          } // ow
        } // oh

        // Case 2.2: oh in bounds
        for (int oh = oH_border_1; oh < oH_border_2; ++oh) {
          // Case 2.2.1: ow might put us out of bounds
          for (int ow = oW_border_0; ow < oW_border_1; ++ow) {
            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
          } // ow

          // Case 2.2.2: ow in bounds
          for (int ow = oW_border_1; ow < oW_border_2; ++ow) {
            pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
          } // ow

          // Case 2.2.3: ow might put us out of bounds
          for (int ow = oW_border_2; ow < oW_border_3; ++ow) {
            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
          } // ow
        } // oh

        // Case 2.3: oh might put us out of bounds
        for (int oh = oH_border_2; oh < oH_border_3; ++oh) {
          for (int ow = 0; ow < oW; ++ow) {
            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
          } // ow
        } // oh
      } // od

      // Case 3: od might put us out of bounds
      for (int od = oD_border_2; od < oD_border_3; ++od) {
        for (int oh = 0; oh < oH; ++oh) {
          for (int ow = 0; ow < oW; ++ow) {
            pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, od, oh, ow);
          } // ow
        } // oh
      } // od

      st_in_ptr += in_stride_N;
      st_out_ptr += out_stride_N;

    } // n
  });
}

void dispatch_slow_conv_1D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip,
    Stream stream) {
  if (in.dtype() == float32) {
    return slow_conv_1D<float>(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        wt_strides,
        wt_dilation,
        in_dilation,
        flip,
        stream);
  } else if (in.dtype() == float16) {
    return slow_conv_1D<float16_t>(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        wt_strides,
        wt_dilation,
        in_dilation,
        flip,
        stream);
  } else if (in.dtype() == bfloat16) {
    return slow_conv_1D<bfloat16_t>(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        wt_strides,
        wt_dilation,
        in_dilation,
        flip,
        stream);
  } else {
    throw std::invalid_argument(
        "[Convolution::eval] got unsupported data type.");
  }
}

void dispatch_slow_conv_2D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip,
    Stream stream) {
  if (in.dtype() == float32) {
    return slow_conv_2D<float>(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        wt_strides,
        wt_dilation,
        in_dilation,
        flip,
        stream);
  } else if (in.dtype() == float16) {
    return slow_conv_2D<float16_t>(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        wt_strides,
        wt_dilation,
        in_dilation,
        flip,
        stream);
  } else if (in.dtype() == bfloat16) {
    return slow_conv_2D<bfloat16_t>(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        wt_strides,
        wt_dilation,
        in_dilation,
        flip,
        stream);
  } else {
    throw std::invalid_argument(
        "[Convolution::eval] got unsupported data type.");
  }
}

void dispatch_slow_conv_3D(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip,
    Stream stream) {
  if (in.dtype() == float32) {
    return slow_conv_3D<float>(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        wt_strides,
        wt_dilation,
        in_dilation,
        flip,
        stream);
  } else if (in.dtype() == float16) {
    return slow_conv_3D<float16_t>(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        wt_strides,
        wt_dilation,
        in_dilation,
        flip,
        stream);
  } else if (in.dtype() == bfloat16) {
    return slow_conv_3D<bfloat16_t>(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        wt_strides,
        wt_dilation,
        in_dilation,
        flip,
        stream);
  } else {
    throw std::invalid_argument(
        "[Convolution::eval] got unsupported data type.");
  }
}

///////////////////////////////////////////////////////////////////////////////
// Explicit gemm conv
///////////////////////////////////////////////////////////////////////////////

// Reverses the spatial axis of a contiguous weight buffer in place.
//
// `x` is read as `out_channels` consecutive filters. Each filter holds
// `spatial_size` positions of `in_channels` contiguous values. Within every
// filter, the order of the positions is reversed, while the channel values at
// one position stay together and in order.
//
// Mirror pointers are only formed inside the swap loop (spatial_size >= 2).
// So spatial_size == 0 or 1 is a no-op that never computes an out-of-range
// address. The previous version computed `(spatial_size - 1) * in_channels`
// up front, which wraps for spatial_size == 0.
template <typename T>
void flip_spatial_dims_inplace(
    T* x,
    size_t in_channels,
    size_t out_channels,
    size_t spatial_size) {
  const size_t filter_size = spatial_size * in_channels;
  for (size_t i = 0; i < out_channels; i++) {
    T* filter = x + i * filter_size;
    // Swap position j with its mirror. The middle position of an odd-length
    // filter stays where it is.
    for (size_t j = 0; j < spatial_size / 2; j++) {
      T* front = filter + j * in_channels;
      T* back = filter + (spatial_size - 1 - j) * in_channels;
      for (size_t k = 0; k < in_channels; k++) {
        std::swap(front[k], back[k]);
      }
    }
  }
}

// Computes a (possibly grouped) 1D convolution on the CPU as an explicit GEMM.
// The zero-padded input is unfolded into a matrix with one row per
// (batch, output position) pair and wH * C columns of patch values. That
// matrix is multiplied by the weights with cblas_sgemm, once per group.
//
// Shapes as read below: `in` is (N, iH, C), `wt` is (O, wH, C / groups) and
// `out` is (N, oH, O). All intermediate buffers are float32. When `out` has a
// different dtype, the float32 result is copied into it at the end.
//
// NOTE(review): `wt_dilation` is never read. The strided view below advances
// one padded input row per kernel tap, so this presumably relies on the
// caller only routing unit-dilation convolutions here. Confirm.
// NOTE(review): `out` is never allocated in this function (it is either the
// sgemm target or the copy_cpu_inplace destination). It is assumed to already
// own a buffer. Confirm against the caller.
void explicit_gemm_conv_1D_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    Stream stream) {
  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
  const int iH = in.shape(1); // Input spatial dim
  const int C = in.shape(2); // Input channels
  const int oH = out.shape(1); // Output spatial dim
  const int O = wt.shape(0); // Out channels
  const int wH = wt.shape(1); // Weight spatial dim

  const int groups = C / wt.shape(2);
  const int C_per_group = wt.shape(2);
  const int O_per_group = O / groups;

  auto conv_dtype = float32;
  auto& encoder = cpu::get_command_encoder(stream);

  // Pad input: float32 buffer of shape (N, padding_lo + iH + padding_hi, C)
  Shape padded_shape = {N, iH + padding_lo[0] + padding_hi[0], C};
  array in_padded(padded_shape, conv_dtype, nullptr, {});

  // Fill with zeros
  // `temps` collects the intermediate arrays. They are handed to the encoder
  // via add_temporaries at the end, presumably so their buffers stay alive
  // until the queued work has run.
  std::vector<array> temps;
  temps.push_back(array(0, conv_dtype));
  copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);

  // Pick input slice from padded
  // (a view into in_padded that starts padding_lo rows in and has in's shape)
  size_t data_offset = padding_lo[0] * in_padded.strides()[1];
  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
      in_padded.strides(),
      in_padded.flags(),
      in_padded_slice.size(),
      data_offset);
  // Copy input values into the slice
  copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
  temps.push_back(in_padded_slice);

  // Make strided view
  // Shape (N, oH, wH, C). Consecutive output positions are wt_strides[0]
  // padded rows apart, and consecutive kernel taps are one padded row apart.
  Shape strided_shape = {N, oH, wH, C};

  Strides strided_strides = {
      in_padded.strides()[0],
      in_padded.strides()[1] * wt_strides[0],
      in_padded.strides()[1],
      in_padded.strides()[2]};
  auto flags = in_padded.flags();
  if (groups > 1) {
    // Transpose the last two dimensions for grouped convolutions
    // so each row is laid out as (C, wH) and a group's channels form a
    // contiguous run of C_per_group * wH columns.
    std::swap(strided_shape[2], strided_shape[3]);
    std::swap(strided_strides[2], strided_strides[3]);
  }

  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
  in_strided_view.copy_shared_buffer(
      in_padded, strided_strides, flags, in_strided_view.size(), 0);

  // Materialize strided view
  // into a dense (N * oH, wH * C) matrix: the GEMM's left operand.
  Shape strided_reshape = {N * oH, wH * C};
  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
  copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
  temps.push_back(in_strided);

  // Check wt dtype and prepare
  // gemm_wt / gemm_out default to wt / out and are replaced by float32
  // temporaries only when a conversion or re-layout is needed.
  auto gemm_wt = wt;
  auto gemm_out = out;

  if (groups > 1) {
    // Transpose the last two dimensions for grouped convolutions
    // so the weight column order (C_per_group, wH) matches the input rows.
    array wt_transpose(
        {wt.shape(0), wt.shape(2), wt.shape(1)}, wt.dtype(), nullptr, {});
    wt_transpose.copy_shared_buffer(
        wt,
        {wt.strides(0), wt.strides(2), wt.strides(1)},
        wt.flags(),
        wt.size(),
        0);
    gemm_wt = array(wt_transpose.shape(), float32, nullptr, {});
    copy_cpu(wt_transpose, gemm_wt, CopyType::General, stream);
    temps.push_back(gemm_wt);
  } else if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
    auto ctype =
        wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
    gemm_wt = array(wt.shape(), float32, nullptr, {});
    copy_cpu(wt, gemm_wt, ctype, stream);
    temps.push_back(gemm_wt);
  }

  if (out.dtype() != float32) {
    gemm_out = array(out.shape(), float32, nullptr, {});
    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
    temps.push_back(gemm_out);
  }

  encoder.set_input_array(in_strided);
  encoder.set_input_array(gemm_wt);
  encoder.set_output_array(gemm_out);

  encoder.dispatch([in_strided_ptr = in_strided.data<float>(),
                    gemm_wt_ptr = gemm_wt.data<float>(),
                    gemm_out_ptr = gemm_out.data<float>(),
                    groups,
                    strided_reshape = strided_reshape[0],
                    O,
                    C,
                    wH,
                    O_per_group,
                    C_per_group]() {
    // One sgemm per group:
    //   out[:, g*O_per_group : (g+1)*O_per_group] =
    //       A_g (N*oH x C_per_group*wH) * B_g^T
    // A_g is the group's column block of in_strided (row stride wH * C).
    // B_g is the group's O_per_group rows of the weights (row stride
    // wH * C_per_group). The output row stride is O.
    for (int g = 0; g < groups; ++g) {
      // Perform gemm
      cblas_sgemm(
          CblasRowMajor,
          CblasNoTrans, // no trans A
          CblasTrans, // transB
          strided_reshape, // M
          O_per_group, // N
          C_per_group * wH, // K
          1.0f, // alpha
          in_strided_ptr + g * C_per_group * wH, // A
          wH * C, // lda
          gemm_wt_ptr + g * O_per_group * C_per_group * wH, // B
          wH * C_per_group, // ldb
          0.0f, // beta
          gemm_out_ptr + g * O_per_group, // C
          O // ldc
      );
    }
  });

  // Copy results if needed
  // (convert the float32 GEMM result into out's dtype)
  if (out.dtype() != float32) {
    copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
  }
  encoder.add_temporaries(std::move(temps));
}

// Convolution lowered to a single float32 GEMM for any number of spatial
// dimensions (inputs laid out as (N, spatial..., C), weights as
// (O, spatial..., C)).
//
// Steps:
//   1. Zero-fill a float32 buffer of shape
//      (N, iDim[i] + padding_lo[i] + padding_hi[i] ..., C) and copy `in`
//      into its interior.
//   2. Build a strided view of that padded buffer with shape
//      (N, oDim..., wDim..., C), where output-spatial strides are scaled by
//      `wt_strides`. Then materialize it as a row-contiguous
//      (N * prod(oDim)) x (prod(wDim) * C) matrix.
//   3. Make the weights a row-contiguous float32 O x (prod(wDim) * C)
//      matrix, flipping the spatial dims in place when `flip` is set.
//   4. out = in_strided * gemm_wt^T via cblas_sgemm. If `out` is not
//      float32, compute into a float32 temp and convert back.
//
// NOTE(review): `wt_dilation` is not read here and there is no input
// dilation parameter. The routing functions below only call this with all
// dilations equal to 1 and with groups == 1, because the GEMM treats all C
// input channels as a single group.
void explicit_gemm_conv_ND_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const bool flip,
    Stream stream) {
  const int N = in.shape(0); // Batch size, should be the same as out.shape(0)
  const auto iDim =
      Shape(in.shape().begin() + 1, in.shape().end() - 1); // Input spatial dim
  const auto oDim = Shape(
      out.shape().begin() + 1, out.shape().end() - 1); // Output spatial dim
  const int O = wt.shape(0); // Out channels
  const int C = wt.shape(-1); // In channels
  const auto wDim =
      Shape(wt.shape().begin() + 1, wt.shape().end() - 1); // Weight spatial dim

  auto conv_dtype = float32;

  auto& encoder = cpu::get_command_encoder(stream);

  // Pad input
  Shape padded_shape(in.shape().size());
  padded_shape.front() = N;
  for (size_t i = 0; i < iDim.size(); i++) {
    padded_shape[i + 1] = iDim[i] + padding_lo[i] + padding_hi[i];
  }
  padded_shape.back() = C;
  array in_padded(padded_shape, conv_dtype, nullptr, {});

  // Fill with zeros
  std::vector<array> temps = {array(0, conv_dtype)};
  copy_cpu(temps.back(), in_padded, CopyType::Scalar, stream);

  // Pick input slice from padded: offset of the first non-padding element.
  size_t data_offset = 0;
  for (size_t i = 0; i < padding_lo.size(); i++) {
    data_offset += padding_lo[i] * in_padded.strides()[i + 1];
  }

  // View of the padded buffer's interior, shaped like `in` but using the
  // padded buffer's strides.
  array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
  in_padded_slice.copy_shared_buffer(
      in_padded,
      in_padded.strides(),
      in_padded.flags(),
      in_padded_slice.size(),
      data_offset);

  // Copy input values into the slice
  copy_cpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, stream);
  temps.push_back(in_padded_slice);

  // Make strided view with shape (N, oDim..., wDim..., C).
  Shape strided_shape(oDim.size() + wDim.size() + 2);
  strided_shape.front() = N;
  for (size_t i = 0; i < oDim.size(); i++) {
    strided_shape[i + 1] = oDim[i];
  }
  for (size_t i = 0; i < wDim.size(); i++) {
    strided_shape[i + 1 + oDim.size()] = wDim[i];
  }
  strided_shape.back() = C;

  // Batch stride, then output-spatial strides (scaled by the conv stride),
  // then kernel-spatial strides and the channel stride (unscaled).
  Strides strided_strides(in.shape().size() * 2 - 2);
  strided_strides[0] = in_padded.strides()[0];
  for (size_t i = 0; i < wt_strides.size(); i++) {
    strided_strides[i + 1] = in_padded.strides()[i + 1] * wt_strides[i];
  }
  for (size_t i = 1; i < in_padded.strides().size(); i++) {
    strided_strides[i + wt_strides.size()] = in_padded.strides()[i];
  }

  auto flags = in_padded.flags();

  array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {});
  in_strided_view.copy_shared_buffer(
      in_padded, strided_strides, flags, in_strided_view.size(), 0);

  // Materialize strided view as an M x K matrix with
  // M = N * prod(oDim) and K = prod(wDim) * C.
  Shape strided_reshape = {N, C};
  for (const auto& o : oDim) {
    strided_reshape[0] *= o;
  }
  for (const auto& w : wDim) {
    strided_reshape[1] *= w;
  }

  array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {});
  copy_cpu(in_strided_view, in_strided, CopyType::General, stream);
  temps.push_back(in_strided);

  // Check wt dtype and prepare
  auto gemm_wt = wt;
  auto gemm_out = out;

  // The GEMM needs the weights as a row-contiguous float32 matrix. Remember
  // whether we already made our own copy, so the flip below can reuse it
  // instead of copying the weights a second time.
  bool wt_is_private_copy = false;
  if (wt.dtype() != float32 || !wt.flags().row_contiguous) {
    auto ctype =
        wt.flags().row_contiguous ? CopyType::Vector : CopyType::General;
    gemm_wt = array(wt.shape(), float32, nullptr, {});
    copy_cpu(wt, gemm_wt, ctype, stream);
    temps.push_back(gemm_wt);
    wt_is_private_copy = true;
  }

  if (flip) {
    // The flip mutates its buffer in place. Never apply it to the caller's
    // weights; copy them first if we do not already own a copy.
    if (!wt_is_private_copy) {
      auto gemm_wt_ = array(gemm_wt.shape(), float32, nullptr, {});
      copy_cpu(gemm_wt, gemm_wt_, CopyType::Vector, stream);
      temps.push_back(gemm_wt_);
      gemm_wt = gemm_wt_;
    }

    // Calculate the total size of the spatial dimensions
    int spatial_size = 1;
    for (int d = 1; d < gemm_wt.ndim() - 1; ++d) {
      spatial_size *= gemm_wt.shape(d);
    }
    encoder.set_output_array(gemm_wt);
    encoder.dispatch([gemm_wt_ptr = gemm_wt.data<float>(),
                      out_channels = gemm_wt.shape(0),
                      in_channels = gemm_wt.shape(-1),
                      spatial_size]() {
      flip_spatial_dims_inplace<float>(
          gemm_wt_ptr, in_channels, out_channels, spatial_size);
    });
  }

  // Non-float32 outputs are computed in a float32 temp and converted below.
  if (out.dtype() != float32) {
    gemm_out = array(out.shape(), float32, nullptr, {});
    gemm_out.set_data(allocator::malloc(gemm_out.nbytes()));
    temps.push_back(gemm_out);
  }

  encoder.set_input_array(in_strided);
  encoder.set_input_array(gemm_wt);
  encoder.set_output_array(gemm_out);

  encoder.dispatch([in_strided_ptr = in_strided.data<float>(),
                    gemm_wt_ptr = gemm_wt.data<float>(),
                    gemm_out_ptr = gemm_out.data<float>(),
                    strided_reshape = std::move(strided_reshape),
                    O]() {
    // Perform gemm: (M x K) * (O x K)^T -> (M x O)
    cblas_sgemm(
        CblasRowMajor,
        CblasNoTrans, // no trans A
        CblasTrans, // transB
        strided_reshape[0], // M
        O, // N
        strided_reshape[1], // K
        1.0f, // alpha
        in_strided_ptr,
        strided_reshape[1], // lda
        gemm_wt_ptr,
        strided_reshape[1], // ldb
        0.0f, // beta
        gemm_out_ptr,
        O // ldc
    );
  });

  // Copy results if needed
  if (out.dtype() != float32) {
    copy_cpu_inplace(gemm_out, out, CopyType::Vector, stream);
  }
  encoder.add_temporaries(std::move(temps));
}

///////////////////////////////////////////////////////////////////////////////
// Conv routing
///////////////////////////////////////////////////////////////////////////////

// Route a 1D convolution to the fastest CPU implementation that supports
// its parameters:
//   - no dilation, no flip          -> explicit_gemm_conv_1D_cpu (grouped GEMM)
//   - no dilation, flip, one group  -> explicit_gemm_conv_ND_cpu
//   - anything else                 -> dispatch_slow_conv_1D (naive loop)
void conv_1D_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip,
    Stream stream) {
  // Input channels divided by the per-group channel count of the weights.
  const int groups = in.shape().back() / wt.shape().back();
  const bool undilated = wt_dilation[0] == 1 && in_dilation[0] == 1;

  if (undilated) {
    // The dedicated 1D GEMM path loops over groups but takes no flip flag.
    if (!flip) {
      return explicit_gemm_conv_1D_cpu(
          in, wt, out, padding_lo, padding_hi, wt_strides, wt_dilation, stream);
    }
    // The N-D GEMM path can flip the kernel but runs a single GEMM, so it
    // only applies to ungrouped convolutions.
    if (groups == 1) {
      return explicit_gemm_conv_ND_cpu(
          in,
          wt,
          out,
          padding_lo,
          padding_hi,
          wt_strides,
          wt_dilation,
          flip,
          stream);
    }
  }

  // Dilated, or flipped and grouped: use the naive reference implementation.
  return dispatch_slow_conv_1D(
      in,
      wt,
      out,
      padding_lo,
      padding_hi,
      wt_strides,
      wt_dilation,
      in_dilation,
      flip,
      stream);
}

// Route a 2D convolution. The single-GEMM path only applies when neither the
// kernel nor the input is dilated and there is exactly one channel group;
// every other configuration uses the naive reference implementation.
void conv_2D_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip,
    Stream stream) {
  const int groups = in.shape().back() / wt.shape().back();
  const bool gemm_applicable = wt_dilation[0] == 1 && wt_dilation[1] == 1 &&
      in_dilation[0] == 1 && in_dilation[1] == 1 && groups == 1;

  if (!gemm_applicable) {
    return dispatch_slow_conv_2D(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        wt_strides,
        wt_dilation,
        in_dilation,
        flip,
        stream);
  }

  return explicit_gemm_conv_ND_cpu(
      in,
      wt,
      out,
      padding_lo,
      padding_hi,
      wt_strides,
      wt_dilation,
      flip,
      stream);
}

// Route a 3D convolution. The single-GEMM path only applies when none of the
// three kernel or input dims is dilated and there is exactly one channel
// group; every other configuration uses the naive reference implementation.
void conv_3D_cpu(
    const array& in,
    const array& wt,
    array out,
    const std::vector<int>& padding_lo,
    const std::vector<int>& padding_hi,
    const std::vector<int>& wt_strides,
    const std::vector<int>& wt_dilation,
    const std::vector<int>& in_dilation,
    bool flip,
    Stream stream) {
  const int groups = in.shape().back() / wt.shape().back();
  const bool gemm_applicable = wt_dilation[0] == 1 && wt_dilation[1] == 1 &&
      wt_dilation[2] == 1 && in_dilation[0] == 1 && in_dilation[1] == 1 &&
      in_dilation[2] == 1 && groups == 1;

  if (!gemm_applicable) {
    return dispatch_slow_conv_3D(
        in,
        wt,
        out,
        padding_lo,
        padding_hi,
        wt_strides,
        wt_dilation,
        in_dilation,
        flip,
        stream);
  }

  return explicit_gemm_conv_ND_cpu(
      in,
      wt,
      out,
      padding_lo,
      padding_hi,
      wt_strides,
      wt_dilation,
      flip,
      stream);
}

} // namespace

void Convolution::eval_cpu(const std::vector<array>& inputs, array& out) {
  out.set_data(allocator::malloc(out.nbytes()));

  auto& in = inputs[0];
  auto& wt = inputs[1];

  // 3D convolution
  if (in.ndim() == (3 + 2)) {
    return conv_3D_cpu(
        in,
        wt,
        out,
        padding_lo_,
        padding_hi_,
        kernel_strides_,
        kernel_dilation_,
        input_dilation_,
        flip_,
        stream());
  }
  // 2D convolution
  else if (in.ndim() == (2 + 2)) {
    return conv_2D_cpu(
        in,
        wt,
        out,
        padding_lo_,
        padding_hi_,
        kernel_strides_,
        kernel_dilation_,
        input_dilation_,
        flip_,
        stream());
  }
  // 1D convolution
  else if (in.ndim() == (1 + 2)) {
    return conv_1D_cpu(
        in,
        wt,
        out,
        padding_lo_,
        padding_hi_,
        kernel_strides_,
        kernel_dilation_,
        input_dilation_,
        flip_,
        stream());
  }
  // Throw error
  else {
    std::ostringstream msg;
    msg << "[Convolution::eval] Convolution currently only supports"
        << " 1D, 2D and 3D convolutions. Got inputs with " << in.ndim() - 2
        << " spatial dimensions";
    throw std::invalid_argument(msg.str());
  }
}

} // namespace mlx::core
